import math
import queue
import time
from collections import deque

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

from environments.cvrp_env import CVRPEnv
from models.actor_critic import ActorCritic
from utils.data_generator import VRPDataGenerator

class Worker(mp.Process):
    """A3C worker process.

    Each worker holds a local copy of the actor-critic network, rolls out
    complete CVRP episodes on freshly generated instances, computes the A3C
    loss locally, copies its gradients onto the shared global model, and
    steps the shared optimizer. Per-episode statistics are pushed to a
    multiprocessing queue for the coordinating A3CAgent.
    """

    def __init__(self, env_config, global_actor_critic, optimizer, global_episode_count,
                 max_episodes, result_queue, worker_idx, args, device):
        """Store shared objects and build the local model.

        Args:
            env_config: dict with keys 'num_nodes', 'capacity', 'depot_idx'.
            global_actor_critic: shared-memory global ActorCritic model.
            optimizer: optimizer bound to the global model's parameters.
            global_episode_count: multiprocessing.Value('i') episode counter.
            max_episodes: total episode budget shared by all workers.
            result_queue: multiprocessing.Queue receiving per-episode stats.
            worker_idx: integer id of this worker (used in log messages).
            args: namespace with hyper-parameters (embedding_dim, hidden_dim,
                gamma, value_loss_coeff, entropy_coeff) and a data_generator.
            device: torch device for the local model and all tensors.
        """
        super(Worker, self).__init__()
        self.env = CVRPEnv(num_nodes=env_config['num_nodes'], capacity=env_config['capacity'], depot_idx=env_config['depot_idx'])
        self.global_actor_critic = global_actor_critic
        self.optimizer = optimizer
        self.global_episode_count = global_episode_count
        self.max_episodes = max_episodes
        self.result_queue = result_queue
        self.worker_idx = worker_idx
        self.args = args
        self.device = device

        # Local ActorCritic used for rollouts; re-synchronized from the
        # global model at the start of every episode.
        self.local_actor_critic = ActorCritic(
            input_dim=4,  # coordinates (2) + demand (1) + visited mask (1)
            embedding_dim=args.embedding_dim,
            hidden_dim=args.hidden_dim,
            num_nodes_max=env_config['num_nodes'] + 1,
            capacity=env_config['capacity']
        )
        self.local_actor_critic.to(self.device)
        self.depot_idx = env_config['depot_idx']

    def run(self):
        """Main worker loop: roll out episodes and push gradients globally."""
        while self.global_episode_count.value < self.max_episodes:
            # Pull the latest global parameters into the local model.
            self.local_actor_critic.load_state_dict(self.global_actor_critic.state_dict())

            coords, demands, capacity, num_customer_nodes = self.args.data_generator.generate_instance()

            # Re-initialize the environment for the fresh instance.
            self.env = CVRPEnv(num_nodes=num_customer_nodes, capacity=capacity, depot_idx=self.depot_idx)
            state = self.env.reset(coords, demands)

            # Prepare model inputs with a batch dimension of 1.
            coords_t = torch.from_numpy(coords).float().unsqueeze(0).to(self.device)
            demands_t = torch.from_numpy(demands).float().unsqueeze(0).to(self.device)
            # Demand normalizer; guard against empty / all-zero demands.
            max_demand_val = demands_t.max() if demands_t.numel() > 0 else torch.tensor(1.0, device=self.device)
            if max_demand_val == 0:
                max_demand_val = torch.tensor(1.0, device=self.device)

            log_probs_list = []
            values_list = []
            rewards_list = []
            logits_list = []

            done = False
            episode_reward = 0.0
            last_action_idx = torch.tensor([self.env.depot_idx], device=self.device)
            hidden_state = None

            step_count = 0
            # Safety cap: roughly two visits per node (including depot returns).
            max_steps_per_episode = (num_customer_nodes + 1) * 2

            while not done and step_count < max_steps_per_episode:
                # Assemble the model's state dict for this step.
                model_state_dict = {
                    'coords': coords_t,
                    'demands': demands_t,
                    'current_node': torch.tensor([state['current_node']], device=self.device),
                    'visited_mask': torch.from_numpy(state['visited_mask']).bool().unsqueeze(0).to(self.device),
                    'current_capacity': torch.tensor([state['current_capacity']]).float().unsqueeze(0).to(self.device),
                    'total_distance': torch.tensor([state['total_distance']]).float().unsqueeze(0).to(self.device),
                    'num_vehicles_used': torch.tensor([state['num_vehicles_used']]).float().unsqueeze(0).to(self.device),
                    'depot_idx': self.depot_idx,
                    'max_demand_val': max_demand_val
                }

                action_mask_np = self.env.get_action_mask()
                action_mask = torch.from_numpy(action_mask_np).bool().unsqueeze(0).to(self.device)
                model_state_dict['action_mask'] = action_mask

                logits, value, hidden_state = self.local_actor_critic(
                    model_state_dict,
                    last_action_idx=last_action_idx,
                    hidden_state=hidden_state
                )

                logits_list.append(logits.squeeze(0))
                values_list.append(value.squeeze(0))

                # Apply the action mask.
                final_logits = logits.clone()
                final_logits[action_mask.logical_not()] = -float('inf')

                probs = F.softmax(final_logits, dim=-1)

                # BUG FIX: softmax over all -inf logits yields NaN, not zeros,
                # so the original `torch.all(probs == 0)` guard never fired
                # and torch.multinomial crashed. Detect NaN as well.
                if torch.isnan(probs).any() or torch.all(probs == 0):
                    valid_indices = torch.where(action_mask.squeeze(0))[0]
                    if len(valid_indices) > 0:
                        # Fall back to a uniform choice among valid actions.
                        uniform_probs = torch.ones(len(valid_indices), device=self.device) / len(valid_indices)
                        chosen_index = torch.multinomial(uniform_probs, 1)
                        action = valid_indices[chosen_index].unsqueeze(0)
                        log_prob_selected = F.log_softmax(final_logits, dim=-1).gather(1, action)
                        if log_prob_selected.isinf().any():
                            # Shape (1, 1) so it stacks with the normal path
                            # after .squeeze(0) below.
                            log_prob_selected = torch.tensor([[-1000.0]], device=self.device)
                        print(f"Worker {self.worker_idx}: 所有概率为0，选择随机有效动作 {action.item()}")
                    else:
                        print(f"Worker {self.worker_idx}: 错误！没有找到有效动作，强制结束")
                        # BUG FIX: the original fell through to env.step(),
                        # which overwrote `done`/`reward` and defeated the
                        # forced stop. Record the penalty and end here.
                        log_probs_list.append(torch.tensor([-1000.0], device=self.device))
                        rewards_list.append(-1e7)
                        episode_reward += -1e7
                        done = True
                        break
                else:
                    action = torch.multinomial(probs, 1)
                    log_prob_selected = F.log_softmax(final_logits, dim=-1).gather(1, action)

                log_probs_list.append(log_prob_selected.squeeze(0))

                next_state, reward, done, info = self.env.step(action.item())

                rewards_list.append(reward)
                episode_reward += reward

                state = next_state
                last_action_idx = action.squeeze(0)
                step_count += 1

            # Penalize episodes that hit the step cap without finishing.
            if step_count >= max_steps_per_episode and not done:
                if rewards_list:
                    rewards_list[-1] -= 1000
                else:
                    rewards_list.append(-1000)
                episode_reward = sum(rewards_list)
                done = True

            # Discounted returns, computed backwards through the episode.
            returns = []
            R = 0
            for r in reversed(rewards_list):
                R = r + self.args.gamma * R
                returns.insert(0, R)

            if not returns:
                returns_t = torch.tensor([], dtype=torch.float32, device=self.device)
            else:
                returns_t = torch.tensor(returns, dtype=torch.float32).to(self.device)

            if not values_list or not log_probs_list:
                total_loss = torch.tensor(0.0, device=self.device)
                print(f"Worker {self.worker_idx}: 回合没有步骤，跳过损失计算")
            else:
                values_t = torch.stack(values_list)
                log_probs_t = torch.stack(log_probs_list)

                # Defensive alignment in case one list is one entry longer.
                min_len = min(values_t.size(0), log_probs_t.size(0), returns_t.size(0))
                values_t = values_t[:min_len]
                log_probs_t = log_probs_t[:min_len]
                returns_t = returns_t[:min_len]

                # BUG FIX: flatten to 1-D before combining. The original mixed
                # (T, 1) values with (T,) returns, so `returns_t - values_t`
                # broadcast to (T, T) and corrupted both losses.
                values_flat = values_t.reshape(min_len)
                log_probs_flat = log_probs_t.reshape(min_len)

                advantages = returns_t - values_flat.detach()

                actor_loss = -(log_probs_flat * advantages).mean()
                critic_loss = F.mse_loss(values_flat, returns_t)

                # Entropy bonus over the (unmasked) policy logits.
                if logits_list:
                    all_step_logits = torch.stack(logits_list).squeeze(1)
                    all_step_logits = all_step_logits[:min_len]
                    all_step_probs = F.softmax(all_step_logits, dim=-1)
                    entropy = -(all_step_probs * torch.log(all_step_probs + 1e-9)).sum(dim=-1).mean()
                else:
                    entropy = torch.tensor(0.0, device=self.device)

                total_loss = actor_loss + self.args.value_loss_coeff * critic_loss - self.args.entropy_coeff * entropy

                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.local_actor_critic.parameters(), max_norm=1.0)

                # Copy gradients from the local model onto the global model
                # (standard A3C: each worker overwrites, then steps).
                for global_param, local_param in zip(self.global_actor_critic.parameters(), self.local_actor_critic.parameters()):
                    if local_param.grad is not None:
                        if global_param.grad is not None:
                            global_param.grad.data.copy_(local_param.grad.data)
                        else:
                            global_param.grad = local_param.grad.clone()

                self.optimizer.step()

            with self.global_episode_count.get_lock():
                self.global_episode_count.value += 1
                current_episode_val = self.global_episode_count.value

            # total_loss is bound on both branches above.
            final_loss = total_loss.item()

            self.result_queue.put({
                'episode': current_episode_val,
                'reward': episode_reward,
                'distance': self.env.total_distance,
                'vehicles': self.env.num_vehicles_used,
                'loss': final_loss
            })

class A3CAgent:
    """Coordinator for A3C training on the CVRP.

    Owns the shared-memory global actor-critic and its optimizer, spawns
    Worker processes, aggregates their per-episode results during training,
    and runs greedy evaluation with the trained global model.
    """

    def __init__(self, env_config, args, device):
        """Build the global model, optimizer, shared state, and workers.

        Args:
            env_config: dict with keys 'num_nodes', 'capacity', 'depot_idx'.
            args: namespace with hyper-parameters (embedding_dim, hidden_dim,
                lr, num_workers, max_episodes, ...) and a data_generator.
            device: torch device for the global model.
        """
        self.env_config = env_config
        self.args = args
        self.device = device

        # Global ActorCritic, placed in shared memory so worker processes
        # can read parameters and write gradients.
        self.global_actor_critic = ActorCritic(
            input_dim=4,
            embedding_dim=args.embedding_dim,
            hidden_dim=args.hidden_dim,
            num_nodes_max=env_config['num_nodes'] + 1,
            capacity=env_config['capacity']
        ).to(self.device)
        self.global_actor_critic.share_memory()

        self.optimizer = optim.Adam(self.global_actor_critic.parameters(), lr=args.lr)

        # Shared episode counter and result channel.
        self.global_episode_count = mp.Value('i', 0)
        self.result_queue = mp.Queue()

        self.workers = []
        for i in range(args.num_workers):
            worker = Worker(env_config, self.global_actor_critic, self.optimizer,
                            self.global_episode_count, args.max_episodes,
                            self.result_queue, i, args, self.device)
            self.workers.append(worker)

    def _drain_results(self, training_results):
        """Move every currently queued result into training_results.

        Uses the queue.Empty exception rather than Queue.empty(): for
        multiprocessing queues empty() is unreliable, and get_nowait() can
        raise Empty even right after empty() returned False.
        """
        while True:
            try:
                training_results.append(self.result_queue.get_nowait())
            except queue.Empty:
                break

    def train(self):
        """Run all workers until the episode budget is met.

        Returns:
            List of per-episode result dicts, sorted by episode number.
        """
        start_time = time.time()
        print(f"在 {self.device} 上开始A3C训练，使用 {self.args.num_workers} 个工作进程...")
        if not hasattr(self.args, 'log_interval'):
            self.args.log_interval = 100

        for worker in self.workers:
            worker.start()

        training_results = []
        pbar = tqdm(total=self.args.max_episodes, desc="训练A3C", unit="回合")

        try:
            last_episode_count = 0
            while self.global_episode_count.value < self.args.max_episodes:
                self._drain_results(training_results)

                current_episode_val = self.global_episode_count.value
                if current_episode_val > last_episode_count:
                    pbar.update(current_episode_val - last_episode_count)
                    last_episode_count = current_episode_val

                    # Periodically report the recent average reward.
                    if current_episode_val % self.args.log_interval == 0 or current_episode_val == self.args.max_episodes:
                        training_results.sort(key=lambda x: x['episode'])
                        recent_results = [r for r in training_results if r['episode'] > current_episode_val - self.args.log_interval]

                        if recent_results:
                            recent_rewards = [r['reward'] for r in recent_results]
                            avg_reward_recent = np.mean(recent_rewards) if recent_rewards else 0
                            pbar.set_postfix({'最近平均奖励': f"{avg_reward_recent:.2f}"})
                        else:
                            pbar.set_postfix({'状态': '收集初始数据'})

                # Brief sleep to avoid busy-waiting on the shared counter.
                time.sleep(0.01)

                # Bail out if every worker died before the episode budget.
                if self.global_episode_count.value < self.args.max_episodes and not any(w.is_alive() for w in self.workers):
                    print("\n所有工作进程似乎提前终止了")
                    self._drain_results(training_results)
                    current_episode_val = self.global_episode_count.value
                    if current_episode_val > last_episode_count:
                        pbar.update(current_episode_val - last_episode_count)
                    break

        except KeyboardInterrupt:
            print("\n训练被用户中断")
        finally:
            pbar.n = self.global_episode_count.value
            pbar.refresh()
            pbar.close()

            for worker in self.workers:
                if worker.is_alive():
                    print(f"终止工作进程 {worker.worker_idx}...")
                    worker.terminate()
                worker.join()

            # BUG FIX: results put on the queue after the last poll were
            # previously dropped; drain once more after the workers stop.
            self._drain_results(training_results)

        end_time = time.time()
        print(f"训练完成，用时 {end_time - start_time:.2f} 秒")
        training_results.sort(key=lambda x: x['episode'])
        return training_results

    def evaluate(self, num_instances=100, data_generator_eval=None):
        """Greedily evaluate the global model on fresh CVRP instances.

        Args:
            num_instances: number of instances to solve.
            data_generator_eval: optional generator; defaults to the
                training args.data_generator.

        Returns:
            List of dicts with 'distance', 'vehicles', and 'time' per instance.
        """
        print(f"\n在 {self.device} 上开始评估...")
        self.global_actor_critic.eval()

        eval_instance_results = []

        if data_generator_eval is None:
            data_generator_eval = self.args.data_generator

        for _ in tqdm(range(num_instances), desc="评估中"):
            start_time_instance = time.time()

            coords, demands, capacity, num_customer_nodes = data_generator_eval.generate_instance()
            coords_t = torch.from_numpy(coords).float().unsqueeze(0).to(self.device)
            demands_t = torch.from_numpy(demands).float().unsqueeze(0).to(self.device)
            # Demand normalizer; guard against empty / all-zero demands.
            max_demand_val = demands_t.max() if demands_t.numel() > 0 else torch.tensor(1.0, device=self.device)
            if max_demand_val == 0:
                max_demand_val = torch.tensor(1.0, device=self.device)

            eval_env = CVRPEnv(num_nodes=num_customer_nodes, capacity=capacity, depot_idx=self.env_config['depot_idx'])
            state = eval_env.reset(coords, demands)

            done = False
            last_action_idx = torch.tensor([eval_env.depot_idx], device=self.device)
            hidden_state = None
            episode_steps = 0
            # Safety cap mirroring the training rollout limit.
            max_eval_steps = (num_customer_nodes + 1) * 2

            while not done and episode_steps < max_eval_steps:
                model_state_dict = {
                    'coords': coords_t,
                    'demands': demands_t,
                    'current_node': torch.tensor([state['current_node']], device=self.device),
                    'visited_mask': torch.from_numpy(state['visited_mask']).bool().unsqueeze(0).to(self.device),
                    'current_capacity': torch.tensor([state['current_capacity']]).float().unsqueeze(0).to(self.device),
                    'total_distance': torch.tensor([state['total_distance']]).float().unsqueeze(0).to(self.device),
                    'num_vehicles_used': torch.tensor([state['num_vehicles_used']]).float().unsqueeze(0).to(self.device),
                    'depot_idx': self.env_config['depot_idx'],
                    'max_demand_val': max_demand_val
                }
                action_mask_np = eval_env.get_action_mask()
                action_mask = torch.from_numpy(action_mask_np).bool().unsqueeze(0).to(self.device)
                model_state_dict['action_mask'] = action_mask

                with torch.no_grad():
                    logits, _, hidden_state = self.global_actor_critic(
                        model_state_dict,
                        last_action_idx=last_action_idx,
                        hidden_state=hidden_state
                    )

                final_logits = logits.clone()
                final_logits[action_mask.logical_not()] = -float('inf')

                if torch.all(final_logits.isinf()):
                    valid_indices = torch.where(action_mask.squeeze(0))[0]
                    if len(valid_indices) > 0:
                        action = valid_indices[torch.randint(0, len(valid_indices), (1,))].unsqueeze(0)
                        print(f"评估：实例遇到问题状态（所有logits为-inf），选择随机有效动作 {action.item()}")
                    else:
                        print(f"评估：错误！所有logits为-inf且没有有效动作。掩码：{action_mask_np}。强制结束")
                        action = torch.tensor([[self.env_config['depot_idx']]], device=self.device)
                        done = True
                else:
                    # BUG FIX: keepdim=True keeps action at shape (1, 1) so
                    # last_action_idx stays (1,), matching both the initial
                    # value and the fallback branch above.
                    action = torch.argmax(final_logits, dim=-1, keepdim=True)

                if not done:
                    next_state, reward, done, info = eval_env.step(action.item())
                    state = next_state
                    last_action_idx = action.squeeze(0)
                    episode_steps += 1
                else:
                    break

            if episode_steps >= max_eval_steps and not done:
                print(f"评估实例达到最大步数 ({max_eval_steps}) 但未完成")

            end_time_instance = time.time()
            computation_time_instance = end_time_instance - start_time_instance

            instance_result = {
                'distance': eval_env.total_distance,
                'vehicles': eval_env.num_vehicles_used,
                'time': computation_time_instance
            }
            eval_instance_results.append(instance_result)

        # Restore training mode for any subsequent training.
        self.global_actor_critic.train()

        return eval_instance_results
